# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Executes Keras benchmarks and accuracy tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import json

from absl import flags
from absl.testing import flagsaver
import tensorflow as tf  # pylint: disable=g-bad-import-order

FLAGS = flags.FLAGS


class KerasBenchmark(tf.test.Benchmark):
  """Base benchmark class with methods to simplify testing."""
  local_flags = None

  def __init__(self, output_dir=None, default_flags=None, flag_methods=None):
    self.output_dir = output_dir
    self.default_flags = default_flags or {}
    self.flag_methods = flag_methods or {}

  def _get_model_dir(self, folder_name):
    return os.path.join(self.output_dir, folder_name)

  def _setup(self):
    """Sets up and resets flags before each test."""
    tf.logging.set_verbosity(tf.logging.DEBUG)
    if KerasBenchmark.local_flags is None:
      for flag_method in self.flag_methods:
        flag_method()
      # Loads flags to get defaults to then override. List cannot be empty.
      flags.FLAGS(['foo'])
      # Overrides flag values with defaults for the class of tests.
      for k, v in self.default_flags.items():
        setattr(FLAGS, k, v)
      saved_flag_values = flagsaver.save_flag_values()
      KerasBenchmark.local_flags = saved_flag_values
    else:
      flagsaver.restore_flag_values(KerasBenchmark.local_flags)

  def _report_benchmark(self,
                        stats,
                        wall_time_sec,
                        top_1_max=None,
                        top_1_min=None,
                        log_steps=None,
                        total_batch_size=None,
                        warmup=1):
    """Report benchmark results by writing to local protobuf file

    Args:
      stats: dict returned from keras models with known entries.
      wall_time_sec: the during of the benchmark execution in seconds
      top_1_max: highest passing level for top_1 accuracy.
      top_1_min: lowest passing level for top_1 accuracy.
      log_steps: How often the log was created for stats['step_timestamp_log'].
      total_batch_size: Global batch-size.
      warmup: number of entries in stats['step_timestamp_log'] to ignore.
    """

    extras = {}
    if 'accuracy_top_1' in stats:
      extras['accuracy'] = self._json_description(
          stats['accuracy_top_1'],
          priority=0,
          min_value=top_1_min,
          max_value=top_1_max)
      extras['top_1_train_accuracy'] = self._json_description(
          stats['training_accuracy_top_1'], priority=1)

    if (warmup and 'step_timestamp_log' in stats and
        len(stats['step_timestamp_log']) > warmup):
      # first entry in the time_log is start of step 1. The rest of the
      # entries are the end of each step recorded
      time_log = stats['step_timestamp_log']
      elapsed = time_log[-1].timestamp - time_log[warmup].timestamp
      num_examples = (
          total_batch_size * log_steps * (len(time_log) - warmup - 1))
      examples_per_sec = num_examples / elapsed
      extras['exp_per_second'] = self._json_description(
          examples_per_sec, priority=2)

    if 'avg_exp_per_second' in stats:
      extras['avg_exp_per_second'] = self._json_description(
          stats['avg_exp_per_second'], priority=3)

    self.report_benchmark(iters=-1, wall_time=wall_time_sec, extras=extras)

  def _json_description(self,
                        value,
                        priority=None,
                        min_value=None,
                        max_value=None):
    """Get a json-formatted string describing the attributes for a metric"""

    attributes = {}
    attributes['value'] = value
    if priority:
      attributes['priority'] = priority
    if min_value:
      attributes['min_value'] = min_value
    if max_value:
      attributes['max_value'] = max_value

    if min_value or max_value:
      succeeded = True
      if min_value and value < min_value:
        succeeded = False
      if max_value and value > max_value:
        succeeded = False
      attributes['succeeded'] = succeeded

    return json.dumps(attributes)
